1 Introduction
Build a model that can filter user comments based on the degree of language maliciousness:
- Preprocess the text by removing the tokens that do not contribute significantly at the semantic level.
- Transform the text corpus into sequences.
- Build a Deep Learning model including recurrent layers for a multilabel classification task.
- At prediction time, the model should return a vector containing a 1 or a 0 for each label in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). In this way, a non-harmful comment will be classified by a vector of only 0s [0,0,0,0,0,0]. In contrast, a dangerous comment will exhibit at least a 1 among the 6 labels.
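As a quick sketch of this output format: the probabilities below are made up, and the 0.5 threshold is only a placeholder (a better value is tuned later in the post).

```python
import numpy as np

# Hypothetical sigmoid outputs for one comment, one probability per label:
# [toxic, severe_toxic, obscene, threat, insult, identity_hate]
probabilities = np.array([0.91, 0.12, 0.77, 0.03, 0.55, 0.08])

# Binarize with a confidence threshold (0.5 is a placeholder value).
prediction = (probabilities >= 0.5).astype(int)
print(prediction)  # -> [1 0 1 0 1 0]

# A comment is harmful if at least one label fires.
is_harmful = prediction.any()
```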
2 Setup
Leveraging Quarto and RStudio, I will set up an R and Python environment.
2.1 Import R libraries
Import the R libraries. These will be used both for rendering the document and for data analysis, since I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.
2.2 Import Python packages
Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp
from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score
Create a Config class to store all the useful parameters for the model and for the project.
2.3 Class Config
I created a class with all the basic configuration of the model, to improve the readability.
Code
class Config():
def __init__(self):
self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
self.max_tokens = 20000
self.output_sequence_length = 911 # check the analysis done to establish this value
self.embedding_dim = 128
self.batch_size = 32
self.epochs = 100
self.temp_split = 0.3
self.test_split = 0.5
self.random_state = 42
self.total_samples = 159571 # total train samples
self.train_samples = 111699
self.val_samples = 23936
self.features = 'comment_text'
self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
self.label_mapping = {label: i for i, label in enumerate(self.labels)}
self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
self.model = self.path + "model_f1.keras"
self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
self.history = self.path + "lstm_model_f1.xlsx"
self.metrics = [
Precision(name='precision'),
Recall(name='recall'),
AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
F1Score(name="f1", average="macro")
]
def get_early_stopping(self):
early_stopping = keras.callbacks.EarlyStopping(
monitor="val_f1", # "val_recall",
min_delta=0.2,
patience=10,
verbose=0,
mode="max",
restore_best_weights=True,
start_from_epoch=3
)
return early_stopping
def get_model_checkpoint(self, filepath):
model_checkpoint = keras.callbacks.ModelCheckpoint(
filepath=filepath,
monitor="val_f1", # "val_recall",
verbose=0,
save_best_only=True,
save_weights_only=False,
mode="max",
save_freq="epoch"
)
return model_checkpoint
def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):
# instantiate KFold
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
threshold_scores = []
for threshold in thresholds:
cv_scores = []
for train_index, val_index in kf.split(ytrue):
ytrue_val = ytrue[val_index]
yproba_val = yproba[val_index]
ypred_val = (yproba_val >= threshold).astype(int)
score = metric(ytrue_val, ypred_val, average="macro")
cv_scores.append(score)
mean_score = np.mean(cv_scores)
threshold_scores.append((threshold, mean_score))
# Find the threshold with the highest mean score
best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
return best_threshold, best_score
config = Config()
3 Data
The dataset is accessible using tf.keras.utils.get_file to fetch the file from the url. N.B. For reproducibility purposes, I also downloaded the dataset, since there were times when the link was unavailable.
# A tibble: 5 × 8
comment_text toxic severe_toxic obscene threat insult identity_hate
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 "Explanation\nWhy the … 0 0 0 0 0 0
2 "D'aww! He matches thi… 0 0 0 0 0 0
3 "Hey man, I'm really n… 0 0 0 0 0 0
4 "\"\nMore\nI can't mak… 0 0 0 0 0 0
5 "You, sir, are my hero… 0 0 0 0 0 0
# ℹ 1 more variable: sum_injurious <dbl>
Let's create a clean variable for EDA purposes: I want to visually see how many observations are clean versus the other labels.
3.1 EDA
First a check on the dataset to find possible missing values and imbalances.
3.1.1 Frequency
Code
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels
df_r_grouped <- df_r %>%
select(all_of(new_labels_r)) %>%
pivot_longer(
cols = all_of(new_labels_r),
names_to = "label",
values_to = "value"
) %>%
group_by(label) %>%
summarise(count = sum(value)) %>%
mutate(freq = round(count / sum(count), 4))
df_r_grouped
# A tibble: 7 × 3
label count freq
<chr> <dbl> <dbl>
1 clean 143346 0.803
2 identity_hate 1405 0.0079
3 insult 7877 0.0441
4 obscene 8449 0.0473
5 severe_toxic 1595 0.0089
6 threat 478 0.0027
7 toxic 15294 0.0857
3.1.2 Barchart
Code
library(reticulate)
barchart <- df_r_grouped %>%
ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
geom_col() +
labs(
x = "Labels",
y = "Count"
) +
# sort bars in descending order
scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
scale_fill_brewer(type = "seq", palette = "RdYlBu") +
theme_minimal()
ggplotly(barchart)
It is visible how imbalanced the dataset is. This means it could be useful to compute the class weights and use this argument during training.
It is clear that most of our comments are clean: 0.8033 of the observations are clean, while only 0.1967 are toxic comments.
3.2 Sequence length definition
To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.
One of its arguments is output_sequence_length: to choose a good value, it is useful to analyze the text lengths. To simulate what the model will do, we remove the punctuation and the newlines from the comments.
3.2.1 Summary
Code
# A tibble: 1 × 6
Min. `1st Qu.` Median Mean `3rd Qu.` Max.
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 4 91 196 378. 419 5000
3.2.2 Boxplot
Code
library(reticulate)
boxplot <- df_r %>%
mutate(
comment_text_clean = comment_text %>%
tolower() %>%
str_remove_all("[[:punct:]]") %>%
str_replace_all("\n", " "),
text_length = comment_text_clean %>% str_count()
) %>%
# pull(text_length) %>%
ggplot(aes(y = text_length)) +
geom_boxplot() +
theme_minimal()
ggplotly(boxplot)
3.2.3 Histogram
Code
library(reticulate)
df_ <- df_r %>%
mutate(
comment_text_clean = comment_text %>%
tolower() %>%
str_remove_all("[[:punct:]]") %>%
str_replace_all("\n", " "),
text_length = comment_text_clean %>% str_count()
)
Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)
histogram <- df_ %>%
ggplot(aes(x = text_length)) +
geom_histogram(bins = 50) +
geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
theme_minimal() +
xlab("Text Length") +
ylab("Frequency") +
xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Considering all the above analysis, I think a good starting value for output_sequence_length is 911, the upper fence of the boxplot (the dashed red vertical line in the last plot). Doing so, we exclude the outliers, which are a small part of our dataset.
3.3 Dataset
Now we can split the dataset in 3: train, test and validation sets. Since sklearn does not provide a function to split into these 3 sets directly, we can do the following:
- split into a train set and a temporary set with a 0.3 split.
- split the temporary set into 2 equally sized test and validation sets.
Code
x = df[config.features].values
y = df[config.labels].values
xtrain, xtemp, ytrain, ytemp = train_test_split(
x,
y,
test_size=config.temp_split, # .3
random_state=config.random_state
)
xtest, xval, ytest, yval = train_test_split(
xtemp,
ytemp,
test_size=config.test_split, # .5
random_state=config.random_state
)
xtrain shape: py$xtrain.shape
ytrain shape: py$ytrain.shape
xtest shape: py$xtest.shape
ytest shape: py$ytest.shape
xval shape: py$xval.shape
yval shape: py$yval.shape
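The resulting split sizes follow from the two fractions. A quick sanity check of the arithmetic, assuming sklearn's behavior of taking the ceiling of n_samples * test_size for the test side:

```python
import math

total = 159571                   # total samples, as stored in Config
temp = math.ceil(total * 0.3)    # temporary set (sklearn ceils the test side)
train = total - temp             # matches config.train_samples
test = math.ceil(temp * 0.5)
val = temp - test                # matches config.val_samples
print(train, val, test)          # -> 111699 23936 23936
```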
The datasets are created using the tf.data.Dataset API, which builds a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. A tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created using from_tensor_slices, which builds a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.
Code
train_ds = (
tf.data.Dataset
.from_tensor_slices((xtrain, ytrain))
.shuffle(xtrain.shape[0])
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
test_ds = (
tf.data.Dataset
.from_tensor_slices((xtest, ytest))
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
val_ds = (
tf.data.Dataset
.from_tensor_slices((xval, yval))
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
Code
train_ds cardinality: 3491
val_ds cardinality: 748
test_ds cardinality: 748
Check the first element of the dataset to be sure that the preprocessing is done correctly.
(array([b'"\n\n Thanks \n\nThanks for the heads up regarding the three revert rule, which I was already familiar with. Since you feel so strongly about this topic (calling the iPod ""iPod"") you should contribute to the discussion on the articles talk page.\nThanks,\n "',
b'"\n\n Please stop. If you continue to violate Wikipedia\'s no original research policy by adding your personal analysis or synthesis into articles, you will be blocked from editing Wikipedia. \xe2\x80\x94\xe2\x80\x94 \nIf this is a shared IP address, and you didn\'t make any unconstructive edits, consider creating an account for yourself so you can avoid further irrelevant warnings."',
b"Please leave \nYou have been abusing your right as an administrator when one writes something that you disagree with. It's about them you step now.",
b"Yes, but I also have made numerous constructive edits, if you haven't noticed. And besides, prove to me that Hitler's middle name wasn't Basil. I could be deadly serious for all you know!",
b'Yeah, I was thinking the same thing. \xe2\x80\x93',
b'June 2014 (UTC)\n\nThanks. If you want to create a submission to Articles for Creation, then feel free to do do. It would be much appreciated. 01:09, 6',
b'I think my current edits are a fair compromise and keep the flow of the article fairly consistent. The only portion I removed was that of the lds primary because it has no point here. The other statements were moved and the section heading was removed. Please comment here and let me know what you think to come to a concensus and end the edit war.',
b'Removing warnings is considered anger that other people type bull**** on my discussion page. Just for your info.',
b'The total population(10 millions) does not match with the e sum of the referenced populationsbelow in the infoboxImages in Infobox',
b"I will ask Giorgos, but I don't think that he will have something good to say. You have being proven to be the number one personal attacker of the team, that is trying to enforce their agenda on anything related with Cyprus. You could have done otherwise, yet not only you let this team to make Cyprus to look, like the worst criminals house, in Wikipedia, but as I see you are leading it as well. What a shame.... By the way, if you think that in this country we will not do anything, and we are going to let you and a few other propagandists portray as as 100 time worst that we are, you are mistaken.",
b"Weisz \n\nI'm touched by your faith in the accuracy of director's dates of birth as recorded by Companies House. IME, an article by 'some journalist' is of equal reliability to the records of Companies House. I'm not going to revert your change again, but I suspect that others may do so. Remember, it's not our place to determine what is correct when there is doubt as to a validity of a fact - we should cite both sources and let the reader determine which one they trust more. YMMV.",
b", your buttocks don't seem constructive to me.",
b"You're mad?. Let me complete this article to this sunday, stop this madness with references, images, etc. Let me complete the article please. 190.242.99.226",
b"Hello! \n\nHello! Please do NOT, edit my profile without reasons. I clearly put UCK and I wrote albanians because UCK is made of albanians. First of all I made many, many contributions that stopped separatism and nazism spreading here and you banned me for what? Hate speech? When I spoke to that member he asked me something, I answered what I know is truth, and you banned me? It's like we talk about rock music and you like rock and I don't then you ban me for 2 weeks and say other reason. Please tell me why did you do this? I am going to give contributions like I did to Brigandine, Resident Evil 3 etc. but tell me why did you ban me? I don't know how much you like albanians I have right to believe in truth, and while I'm not insulting anyone with my userboxes you can't delete them just because you are an administrator. Please, answer this time with an answer not ban. Good luck!",
b'Wikipedia talk:Requests for arbitration/West Bank - Judea and Samaria.',
b'TODAY SHOW HAS BEEN NUMBERf 1 FOR 15 YEARS. MATT HAS BEEN ON THE SHO SINCE 1994 AND NATALIE HAS BEEN ON SINCE 2003. \n\n Please unblock me. \n\n{unblock|reason=Your reason here }} please unblock me. I know I am right.',
b'"\n\n Image:Tax court.gif \n\nSeriously, look at the image in the upper right hand corner of the webpage linked from the image page - it is indeed animated. Cheers! T "',
b'"\n{| style=""background-color:#F5FFFA; padding:0;"" cellpadding=""0""\n|style=""border:1px solid #084080; background-color:#F5FFFA; vertical-align:top; color:#000000;""|\n Hello, ! Welcome to Wikipedia! Thank you for your contributions to this free encyclopedia. If you decide that you need help, check out Getting Help below, ask me on , or place on your talk page and ask your question there. Please remember to sign your name on talk pages by clicking or using four tildes (~~~~); this will automatically produce your username and the date. Finally, please do your best to always fill in the edit summary field. Below are some useful links to facilitate your involvement. Happy editing! \xe2\x80\x94 tizzle \n{| width=""100%"" style=""background-color:#F5FFFA;""\n|style=""width: 55%; border:1px solid #FFFFFF; background-color:#F5FFFA; vertical-align:top""|\n Getting started A tutorial \xe2\x80\xa2 Our five pillars \xe2\x80\xa2 Getting mentored\n How to: edit a page \xe2\x80\xa2 upload and use images Getting help Frequently asked questions \xe2\x80\xa2 Tips\n Where to ask questions or make comments\n Request administrator attention Policies and guidelines Neutral point of view \xe2\x80\xa2 No original research \n Verifiability \xe2\x80\xa2 Reliable sources \xe2\x80\xa2 Citing sources\n What Wikipedia is not \xe2\x80\xa2 Biographies of living persons\n\n Manual of Style \xe2\x80\xa2 Three-revert rule \xe2\x80\xa2 Sock puppetry\n Copyrights \xe2\x80\xa2 Policy for non-free content \xe2\x80\xa2 Image use policy\n External links \xe2\x80\xa2 Spam \xe2\x80\xa2 Vandalism\n Deletion policy \xe2\x80\xa2 Conflict of interest \xe2\x80\xa2 Notability\n|class=""MainPageBG"" style=""width: 55%; border:1px solid #FFFFFF; background-color:#F5FFFA; vertical-align:top""|\n{| width=""100%"" cellpadding=""0"" cellspacing=""5"" style=""vertical-align:top; background-color:#F5FFFA""\n! 
The community\n|-\n|style=""color:#000""|\n Build consensus \xe2\x80\xa2 Resolve disputes\n Assume good faith \xe2\x80\xa2 Civility \xe2\x80\xa2 Etiquette\n No personal attacks \xe2\x80\xa2 [[Wikipedia:No legal threats|No legal t',
b'"\n\n Talk:Revenue Tariff Party \n\nShould have discussed first. A reliable source though. | Talk "',
b'"\n\nI changed my mind. I remembered I had this audio converter on ym comp, and so I went and coverted a 320kbps CBR MP3 version of ""Let Me Give the World to You"" straight to OGG format, and uploaded it. You can find it as Image:Smashing Pumpkins - CR-04 - 06 - Let Me Give the World to You.ogg. I chose Let Me Give the World to You"" because apparently ""Had the album been crystallized as an ""official"" pressed release, the song was considered as a single"". It does not ""sound really ass"" to me, so lemme know what you think about it, if I have the correct licence for it, etc. X "',
b'AGAIN YOU SLANDER HIS NAME! \n\nI get strange message that I cannot attack user and must attack content! I am attacking the content, the content of a wrong that was unjustly delivered upon one of my Brotherhood members. LEITMOTIV.',
b'put bad words in wikipedia',
b'"\n\n Thanks... \n\n...for catching this one. Keep up the good work. \'\'\'rolls\'\'\' "',
b'Hello world, how does this work?',
b'Fuck off fatty. Get a life.',
b'your retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retarded \n\nyour retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retarded your retarded your retardedyour retardedyour retarded your retardedyour retardedyour retarded your retardedyour retarded',
b'"\n\nThis article should be deleted. Taiwan Province has always been controlled by the Republic of China. The PRC has never set foot in Taiwan. Therefore, it is completely illegitimate and unreasonable to have a ""Taiwan Province, PRC"" when it doesn\'t exist. It is only a fantasy made up by the Communist government. "',
b'Note: Sockpuppet of several other accounts. Contact me before unblocking.',
b"I'M A GOOFY GOOBER YEAH POOOOOOOOOOOOOO COMES OUT MY BUTT> \n\nblahhhhhhhhhhhblahhhhhhhhhblahblahblah. grrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr. growllllllllll. rawrrrrrrr shkiyss. grrrrOWLSSSSSSSSSSS. PPPPP\nyeeeeeeeaaahhhhhBOiiiiii. i just personally attackedd you. WHAT NOWWW!",
b'"I don\'t see how the fact that he was not with the band by the time ""Radio Star"" was released supports that he was ""never a member"". Unless I\'ve missed something, being mentioned at Allmusic is not a requirement. "',
b"SatuSuro User page\nHas been dogs breakfast too long - I have hundreds if not thousands west coast tas photos ...that I never even printed in the 70's!... that I am planning to upload. Gnang has been very helpful with showing the workings of commonist - have to choose my times as the boys are over enthusiastic users of world of warcraft and its holidays... sigh SatuSuro",
b'"\n\ndave thomas\njust to tell you that i\'ll be fair with you if you\'ll be fair with me!\n\nThank you. \xe2\x9e\xa8 "'],
dtype=object), array([[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 1, 1],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]))
We also check the shapes. We expect features of shape (batch,) and targets of shape (batch, number of labels).
Code
text train shape: (32,)
text train type: object
label train shape: (32, 6)
label train type: int64
4 Preprocessing
Of course preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:
1. Standardize each example (usually lowercasing + punctuation stripping).
2. Split each example into substrings (usually words).
3. Recombine substrings into tokens (usually ngrams).
4. Index tokens (associate a unique int value with each token).
5. Transform each example using this index, either into a vector of ints or a dense float vector.
For more reference, see the documentation at the following link.
Code
text_vectorization = TextVectorization(
max_tokens=config.max_tokens,
standardize="lower_and_strip_punctuation",
split="whitespace",
output_mode="int",
output_sequence_length=config.output_sequence_length,
pad_to_max_tokens=True
)
# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)
This layer is set to:
- max_tokens: 20000. It is common for text classification. It is the maximum size of the vocabulary for this layer.
- output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
- output_mode: outputs integer indices, one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
- standardize: "lower_and_strip_punctuation".
- split: on whitespace.
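To see the five preprocessing steps in action, here is a toy example on a tiny made-up corpus (the sentences and the small parameter values are illustrative only):

```python
from keras.layers import TextVectorization

# Tiny illustrative corpus.
corpus = ["The comment was Fine!", "the comment was rude", "fine, really fine"]

toy_vectorizer = TextVectorization(
    max_tokens=20,                               # small vocabulary for the demo
    standardize="lower_and_strip_punctuation",
    split="whitespace",
    output_mode="int",
    output_sequence_length=6,                    # shorter texts are padded with 0
)
toy_vectorizer.adapt(corpus)

# Index 0 is reserved for padding/masking, index 1 for out-of-vocabulary tokens;
# the rest of the vocabulary is sorted by frequency.
vocab = toy_vectorizer.get_vocabulary()
vectors = toy_vectorizer(corpus)                 # int tensor of shape (3, 6)
```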
To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization layer, it is possible to map it over the features of each dataset.
Code
processed_train_ds = train_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
5 Model
5.1 Definition
Define the model using the Functional API.
Code
def get_deeper_lstm_model():
clear_session()
inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
embedding = Embedding(
input_dim=config.max_tokens,
output_dim=config.embedding_dim,
mask_zero=True,
name="embedding"
)(inputs)
x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
# Global average pooling
x = GlobalAveragePooling1D()(x)
# Add regularization
x = Dropout(0.3)(x)
x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
x = LayerNormalization()(x)
outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
model = Model(inputs, outputs)
model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
return model
lstm_model = get_deeper_lstm_model()
lstm_model.summary()
5.2 Callbacks
Finally, the model has been trained using 2 callbacks:
- Early Stopping, to avoid consuming the Kaggle GPU time.
- Model Checkpoint, to retrieve the best model training information.
5.3 Final preparation before fit
Since the dataset is imbalanced, to improve performance we need to calculate the class weights. These will be passed during the training of the model.
class_weight
toxic 0.095900590
severe_toxic 0.009928468
obscene 0.052757858
threat 0.003061800
insult 0.049132042
identity_hate 0.008710911
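The table above appears to contain the per-label frequencies over the training samples. A minimal sketch of how such values can be computed (the target matrix below is made up; whether the frequencies are used directly or inverted before being passed to fit is a modeling choice not shown in the post):

```python
import numpy as np

# Illustrative multilabel target matrix: rows are comments, columns are labels.
y = np.array([
    [1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 0, 1, 0],
    [0, 0, 0, 0, 0, 0],
])

# Per-label frequency: fraction of samples carrying each label.
frequencies = y.mean(axis=0)

# Keras expects class_weight as a dict {label_index: weight}.
class_weight = {i: float(f) for i, f in enumerate(frequencies)}
```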
It is also useful to define the steps per epoch for the train and validation datasets. This is required to avoid consuming the dataset entirely during the fit, which happened to me.
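The step counts follow directly from the sample counts and batch size stored in Config; a quick sketch of the arithmetic (ceiling division, so the final partial batch is counted, matching the dataset cardinalities printed earlier):

```python
import math

train_samples = 111699   # config.train_samples
val_samples = 23936      # config.val_samples
batch_size = 32          # config.batch_size

steps_per_epoch = math.ceil(train_samples / batch_size)   # -> 3491
validation_steps = math.ceil(val_samples / batch_size)    # -> 748
```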
5.4 Fit
The fit has been done on Kaggle to leverage the GPU. Some considerations about the model:
- .repeat() ensures the model sees the whole dataset.
- epochs is set to 100.
- validation_data has the same repeat.
- callbacks are the ones defined before.
- class_weight ensures the model is trained using the frequency of each class, because our dataset is imbalanced.
- steps_per_epoch and validation_steps depend on the use of repeat.
Now we can import the model and the history trained on Kaggle.
5.5 Evaluate
Code
# A tibble: 5 × 2
metric value
<chr> <dbl>
1 loss 0.0542
2 precision 0.789
3 recall 0.671
4 auc 0.957
5 f1_score 0.0293
5.6 Predict
For prediction, the model does not need a repeated dataset, because it has already been trained on all of the train data. Now it just has to consume the new data once to make the predictions.
5.7 Confusion Matrix
The best way to assess the performance of a multilabel classification is a confusion matrix. Sklearn has a specific function, multilabel_confusion_matrix, which handles the fact that there can be multiple labels for one prediction.
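A toy example of what multilabel_confusion_matrix returns: one 2×2 matrix per label, each laid out as [[TN, FP], [FN, TP]] (the vectors below are made up):

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Made-up ground truth and predictions for 4 comments and 2 labels.
ytrue = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
ypred = np.array([[1, 0], [0, 0], [1, 1], [0, 1]])

mcm = multilabel_confusion_matrix(ytrue, ypred)
# mcm has shape (n_labels, 2, 2): one confusion matrix per label.
print(mcm[0])  # label 0 predicted perfectly -> [[2 0], [0 2]]
```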
5.7.1 Grid Search Cross Validation for best threshold
Grid Search CV is a technique for fine-tuning the hyperparameters of a ML model. It systematically searches through a set of hyperparameter values to find the combination which leads to the best model performance. In this case, I combine it with KFold Cross Validation, a resampling technique that splits the data into k consecutive folds. Each fold is used once as a validation set while the k - 1 remaining folds form the training set. See the documentation for more information.
The model is trained to optimize the recall. The decision was made because the cost of a False Negative is greater than that of a False Positive: missing an injurious observation is worse than classifying a clean one as bad.
5.7.2 Confidence threshold and Precision-Recall trade off
Whilst the KFold grid search technique is useful to test multiple hyperparameters, it is important to understand the problem we are facing. A multilabel deep learning classifier outputs a vector of per-class probabilities. These need to be converted to a binary vector using a confidence threshold.
- The higher the threshold, the fewer classes the model predicts, increasing model confidence [higher Precision] but missing more classes [lower Recall].
- The lower the threshold, the more classes the model predicts, decreasing model confidence [lower Precision] but missing fewer classes [higher Recall].
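A small numeric illustration of the trade-off, with made-up probabilities for a single label:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Made-up per-comment probabilities and ground truth for one label.
ytrue = np.array([1, 1, 1, 0, 0, 0])
yproba = np.array([0.9, 0.6, 0.2, 0.55, 0.3, 0.1])

# Lowering the threshold catches more true positives (recall up)
# at the cost of more false positives (precision down).
for threshold in (0.5, 0.15):
    ypred = (yproba >= threshold).astype(int)
    print(threshold, precision_score(ytrue, ypred), recall_score(ytrue, ypred))
```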
Threshold selection means we have to decide which metric to prioritize, based on the problem we are facing and the relative cost of misjudging. We can consider toxic comment filtering a problem similar to cancer diagnostics. It is better to predict cancer in people who do not have it [False Positive] and perform further analysis than to miss the cancer when the patient has the disease [False Negative].
I decided to train the model on the F1 score, to obtain a model balanced in both precision and recall, and to leave it to the threshold selection to increase the recall performance.
Moreover, the model has been trained on the macro average F1 score, which is a single performance indicator obtained by averaging the F1 scores of the individual classes.
\[ F1\ macro\ avg = \frac{\sum_{i=1}^{n} F1_i}{n} \]
It is useful with imbalanced classes, because it weights each class equally: it is not influenced by the number of samples in each class. This is set both in config.metrics and in find_optimal_threshold_cv.
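A quick check of the formula on made-up vectors: the macro average is the unweighted mean of the per-class F1 scores, regardless of each class's support:

```python
import numpy as np
from sklearn.metrics import f1_score

# Made-up multilabel ground truth and predictions (2 labels).
ytrue = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
ypred = np.array([[1, 0], [0, 0], [1, 1], [0, 1]])

per_class = f1_score(ytrue, ypred, average=None)   # one F1 per label
macro = f1_score(ytrue, ypred, average="macro")    # their unweighted mean
print(per_class, macro)  # -> [1.  0.5] 0.75
```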
f1_score
Code
Optimal threshold: 0.15000000000000002
Best score: 0.4788653077945807
Optimal threshold f1 score: 0.15. Best score: 0.4788653.
recall_score
Code
Optimal threshold recall: 0.05. Best score: 0.8095814.
roc_auc_score
Code
Optimal threshold: 0.05
Best score: 0.8809499649742268
Optimal threshold roc: 0.05. Best score: 0.88095.
5.7.3 Confusion Matrix Plot
Code
# convert probability predictions to predictions
ypred = predictions >= optimal_threshold_recall # .05
ypred = ypred.astype(int)
# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(ax=axes[i], colorbar=False)
axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
5.8 Classification Report
Code
# A tibble: 10 × 5
metrics precision recall `f1-score` support
<chr> <dbl> <dbl> <dbl> <dbl>
1 toxic 0.552 0.890 0.682 2262
2 severe_toxic 0.236 0.917 0.375 240
3 obscene 0.550 0.936 0.692 1263
4 threat 0.0366 0.493 0.0681 69
5 insult 0.471 0.915 0.622 1170
6 identity_hate 0.116 0.720 0.200 207
7 micro avg 0.416 0.896 0.569 5211
8 macro avg 0.327 0.812 0.440 5211
9 weighted avg 0.495 0.896 0.629 5211
10 samples avg 0.0502 0.0848 0.0597 5211
6 Conclusions
The BiLSTM model, optimized to have a high recall, is performing well enough to make predictions for each label. Considering the low support for the threat label, the performance is not bad. See Table 2 and Figure 1: the threat label is only 0.27 % of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of considering a clean comment injurious.
Possible improvements could be to increase the number of observations, especially for the threat label. In general there are too many clean comments. This could be mitigated by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.
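As an illustration of the undersampling idea mentioned above, here is a pandas sketch on a toy frame (the column names follow the post; the 1:1 sampling ratio is an arbitrary assumption, not something the post prescribes):

```python
import pandas as pd

# Toy frame mimicking the dataset structure: "clean" flags comments
# with no toxic label, as built in the EDA section.
df = pd.DataFrame({
    "comment_text": ["ok", "fine", "nice", "bad one", "worse one"],
    "clean": [1, 1, 1, 0, 0],
})

toxic_rows = df[df["clean"] == 0]
clean_rows = df[df["clean"] == 1]

# Keep only as many clean comments as toxic ones, then shuffle.
clean_sampled = clean_rows.sample(n=len(toxic_rows), random_state=42)
balanced = pd.concat([toxic_rows, clean_sampled]).sample(frac=1, random_state=42)
```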